"""Loss wrappers and a joint train-one-step cell for GAN training.

Written against the MindSpore 1.x API (the mindspore.ops.operations/
functional/composite namespaces and mindspore.context).
"""
from mindspore import nn
import mindspore.ops.operations as P
import mindspore.ops.functional as F
import mindspore.ops.composite as C
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode)
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer


class Reshape(nn.Cell):
    """Wraps P.Reshape so a fixed target shape can be applied inside a Cell."""

    def __init__(self, shape, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        self.shape = shape
        self.reshape = P.Reshape()

    def construct(self, x):
        # Reshape the input to the target shape fixed at construction time.
        return self.reshape(x, self.shape)
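
# A minimal usage sketch for Reshape (the shapes here are illustrative, not
# from the original file):
#
#   import numpy as np
#   from mindspore import Tensor
#
#   reshape = Reshape((2, 12))
#   y = reshape(Tensor(np.zeros((2, 3, 4), np.float32)))  # y.shape == (2, 12)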


class SigmoidCrossEntropyWithLogits(nn.loss.loss._Loss):
    """Binary cross-entropy on raw logits.

    Subclassing the (private) nn.loss.loss._Loss base class provides the
    standard reduction to a scalar through self.get_loss.
    """

    def __init__(self):
        super().__init__()
        self.cross_entropy = P.SigmoidCrossEntropyWithLogits()

    def construct(self, data, label):
        # Element-wise sigmoid cross-entropy, then reduce to a scalar loss.
        x = self.cross_entropy(data, label)
        return self.get_loss(x)
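
# For reference, P.SigmoidCrossEntropyWithLogits computes, element-wise on
# logits x and labels z,
#   -z * log(sigmoid(x)) - (1 - z) * log(1 - sigmoid(x))
# in the numerically stable form max(x, 0) - x * z + log(1 + exp(-|x|)).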


class GenWithLossCell(nn.Cell):
    """Computes the generator loss.

    The generator is trained to make the discriminator classify its fakes as
    real, so the targets are all ones (the non-saturating generator loss).
    """

    def __init__(self, netG, netD, loss_fn, auto_prefix=True):
        super(GenWithLossCell, self).__init__(auto_prefix=auto_prefix)
        self.netG = netG
        self.netD = netD
        self.loss_fn = loss_fn

    def construct(self, latent_code):
        fake_data = self.netG(latent_code)
        fake_out = self.netD(fake_data)
        # Label the fakes as real (ones) from the generator's point of view.
        loss_G = self.loss_fn(fake_out, F.ones_like(fake_out))
        return loss_G


class DisWithLossCell(nn.Cell):
    """Computes the discriminator loss.

    The discriminator is trained to output ones on real samples and zeros on
    generated ones; the two cross-entropy terms are summed.
    """

    def __init__(self, netG, netD, loss_fn, auto_prefix=True):
        super(DisWithLossCell, self).__init__(auto_prefix=auto_prefix)
        self.netG = netG
        self.netD = netD
        self.loss_fn = loss_fn

    def construct(self, real_data, latent_code):
        fake_data = self.netG(latent_code)
        real_out = self.netD(real_data)
        real_loss = self.loss_fn(real_out, F.ones_like(real_out))
        fake_out = self.netD(fake_data)
        fake_loss = self.loss_fn(fake_out, F.zeros_like(fake_out))
        loss_D = real_loss + fake_loss
        return loss_D
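
# Together the two loss cells implement the usual GAN objective, with BCE
# evaluated on logits by SigmoidCrossEntropyWithLogits above:
#   L_D = BCE(D(x), 1) + BCE(D(G(z)), 0)
#   L_G = BCE(D(G(z)), 1)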


class TrainOneStepCell(nn.Cell):
    """Runs one optimization step for both networks in a single construct call.

    Note that netG and netD here are the loss-wrapped cells defined above
    (GenWithLossCell and DisWithLossCell), not the bare networks; each
    optimizer's parameter list selects which weights its step updates.
    """

    def __init__(
        self,
        netG,
        netD,
        optimizerG: nn.Optimizer,
        optimizerD: nn.Optimizer,
        sens=1.0,
        auto_prefix=True,
    ):
        super(TrainOneStepCell, self).__init__(auto_prefix=auto_prefix)
        self.netG = netG
        self.netG.set_grad()
        self.netG.add_flags(defer_inline=True)

        self.netD = netD
        self.netD.set_grad()
        self.netD.add_flags(defer_inline=True)

        self.weights_G = optimizerG.parameters
        self.optimizerG = optimizerG
        self.weights_D = optimizerD.parameters
        self.optimizerD = optimizerD

        # Differentiate with respect to a weight list, seeding backpropagation
        # with an explicit sensitivity value (sens_param=True).
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)

        self.sens = sens
        # In data-parallel or hybrid-parallel mode, all-reduce the gradients
        # across devices; otherwise the reducers stay identity functions.
        self.reducer_flag = False
        self.grad_reducer_G = F.identity
        self.grad_reducer_D = F.identity
        self.parallel_mode = _get_parallel_mode()
        if self.parallel_mode in (ParallelMode.DATA_PARALLEL,
                                  ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer_G = DistributedGradReducer(
                self.weights_G, mean, degree)
            self.grad_reducer_D = DistributedGradReducer(
                self.weights_D, mean, degree)

    def trainD(self, real_data, latent_code, loss, loss_net, grad, optimizer,
               weights, grad_reducer):
        """One discriminator update: backpropagate, reduce, apply optimizer."""
        # Seed backpropagation with a tensor of self.sens matching the loss.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = grad(loss_net, weights)(real_data, latent_code, sens)
        grads = grad_reducer(grads)
        # Return the loss, ordered after the optimizer update via depend.
        return F.depend(loss, optimizer(grads))

    def trainG(self, latent_code, loss, loss_net, grad, optimizer, weights,
               grad_reducer):
        """One generator update: backpropagate, reduce, apply optimizer."""
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = grad(loss_net, weights)(latent_code, sens)
        grads = grad_reducer(grads)
        return F.depend(loss, optimizer(grads))

    def construct(self, real_data, latent_code):
        # Forward passes for the current losses, then one update per network.
        loss_D = self.netD(real_data, latent_code)
        loss_G = self.netG(latent_code)
        d_out = self.trainD(real_data, latent_code, loss_D, self.netD,
                            self.grad, self.optimizerD, self.weights_D,
                            self.grad_reducer_D)
        g_out = self.trainG(latent_code, loss_G, self.netG, self.grad,
                            self.optimizerG, self.weights_G,
                            self.grad_reducer_G)
        return d_out, g_out
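

# A minimal wiring sketch (Generator and Discriminator are hypothetical cells
# not defined in this file; shapes and hyperparameters are illustrative):
#
#   import numpy as np
#   from mindspore import Tensor
#
#   netG, netD = Generator(), Discriminator()
#   loss_fn = SigmoidCrossEntropyWithLogits()
#   gen_with_loss = GenWithLossCell(netG, netD, loss_fn)
#   dis_with_loss = DisWithLossCell(netG, netD, loss_fn)
#   optG = nn.Adam(netG.trainable_params(), learning_rate=2e-4, beta1=0.5)
#   optD = nn.Adam(netD.trainable_params(), learning_rate=2e-4, beta1=0.5)
#   train_step = TrainOneStepCell(gen_with_loss, dis_with_loss, optG, optD)
#   train_step.set_train()
#
#   real = Tensor(np.random.rand(16, 1, 28, 28).astype(np.float32))
#   latent = Tensor(np.random.randn(16, 100).astype(np.float32))
#   loss_d, loss_g = train_step(real, latent)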
